package org.bdware.analysis.dynamic;

import org.bdware.analysis.BasicBlock;
import org.bdware.analysis.BreadthFirstSearch;
import org.bdware.analysis.OpInfo;
import org.bdware.analysis.taint.*;
import org.objectweb.asm.Opcodes;
import org.objectweb.asm.tree.*;

import java.util.*;

public class NaiveDynamicTaintAnalysis extends BreadthFirstSearch<TaintResult, TaintBB> {

    TaintCFG cfg;
    TracedFile tf;
    public static boolean isDebug = false;
    int count = 0;
    String functionGlobel;
    public static HashMap<String, HashMap<Integer, Integer>> branchCount = new HashMap<>();

    public NaiveDynamicTaintAnalysis(TaintCFG cfg, TracedFile tf) {
        this.cfg = cfg;
        this.tf = tf;
        List<TaintBB> toAnalysis = new ArrayList<>();
        MethodNode mn = cfg.getMethodNode();

        String methodDesc = mn.desc;
        methodDesc = methodDesc.replaceAll("\\).*$", ")");
        int pos = 2;
        if (methodDesc.split(";").length == 3) pos = 1;
        // TODO add inputBlock!
        TaintBB b = (TaintBB) cfg.getBasicBlockAt(0);
        b.preResult = new TaintResult();
        int arg = cfg.argsLocal();
        b.preResult.frame.setLocal(arg, new TaintValue(1, cfg.taintBits.allocate("arg" + arg)));
        b.preResult.frame.setLocal(0, HeapObject.getRootObject());
        cfg.executeLocal();
        TaintResult.interpreter.setTaintBits(cfg.taintBits);
        functionGlobel = cfg.getMethodNode().name;
        branchCount.put(functionGlobel, new HashMap<Integer, Integer>());
        // local0=scriptfuncion, is not tainted;
        // local1=this, is not tainted;
        // local2=this, is not tainted;
        // NaiveTaintResult.printer.setLabelOrder(cfg.getLabelOrder());
        toAnalysis.add(b);
        b.setInList(true);
        setToAnalysis(toAnalysis);
        if (isDebug) {
            System.out.println("===Method:" + cfg.getMethodNode().name + cfg.getMethodNode().desc);
            System.out.println(
                    "===Local:"
                            + cfg.getMethodNode().maxLocals
                            + " "
                            + cfg.getMethodNode().maxStack);
        }
    }

    @Override
    public TaintResult execute(TaintBB t) {
        return t.forwardAnalysis();
    }

    // Current Block is done, merge sucResult to sucBlocks!
    @Override
    public Collection<TaintBB> getSuc(TaintBB t) {
        Set<BasicBlock> subBlock = new HashSet<>(); // bug ! clear() no suc
        if (t.list.size() > 0) {
            AbstractInsnNode insn = t.lastInsn();
            if (insn != null) {
                OpInfo info = null;
                if (insn.getOpcode() >= 0) info = OpInfo.ops[insn.getOpcode()];
                if (info == null) {
                    subBlock.addAll(cfg.getSucBlocks(t));
                } else if (info.canThrow()) {
                    callCount(insn);
                    subBlock.add(cfg.getBasicBlockAt(t.blockID + 1));
                } else if (OpInfo.ops[insn.getOpcode()].canBranch()) {
                    subBlock.clear();
                    int block = handleBranchCase(subBlock, insn, t);
                    subBlock.add(cfg.getBasicBlockAt(block));
                } else if (OpInfo.ops[insn.getOpcode()].canSwitch()) {
                    subBlock.add(cfg.getBasicBlockAt(t.blockID + 1));
                } else {
                    subBlock.addAll(cfg.getSucBlocks(t));
                }
            }
        }
        Set<TaintBB> ret = new HashSet<>();
        // System.out.print("[NaiveDynamicTaintAnalysis] nextBB:");
        for (BasicBlock bb : subBlock) {
            // System.out.print(" " + bb.blockID);
            TaintBB ntbb = (TaintBB) bb;
            ntbb.preResult.mergeResult(t.sucResult);
            ret.add(ntbb);
        }
        // System.out.println();
        return ret;
    }

    private void callCount(AbstractInsnNode insn) {
        if (insn instanceof InvokeDynamicInsnNode) {
            Object invoke = ((InvokeDynamicInsnNode) insn).bsmArgs[0];
            String functionName = ((InvokeDynamicInsnNode) insn).name;
            if (functionName.contains("traceif")) {
                traceIfNum = (int) invoke;
                if (branchCount.get(functionGlobel).containsKey(traceIfNum)) {
                    branchCount
                            .get(functionGlobel)
                            .put(traceIfNum, branchCount.get(functionGlobel).get(traceIfNum) + 1);
                } else {
                    branchCount.get(functionGlobel).put(traceIfNum, 1);
                }
            }
        }
    }

    int traceIfNum = 0;

    private int handleBranchCase(Set<BasicBlock> subBlock, AbstractInsnNode insn, TaintBB t) {
        int blockid = 0;
        switch (insn.getOpcode()) {
            case Opcodes.IFNE: // succeeds if and only if value ≠ 0
                int ifneCount = branchCount.get(functionGlobel).get(traceIfNum);
                if (tf.trans.get(0).tmToVal.get(traceIfNum).get(ifneCount) != 0) { // test first
                    if (insn instanceof JumpInsnNode) {
                        LabelNode jump = ((JumpInsnNode) insn).label;
                        blockid = cfg.getBasicBlockByLabel(jump.getLabel()).blockID;
                    }
                } else {
                    blockid = t.blockID + 1;
                }
                return blockid;
            case Opcodes.IFEQ: // succeeds if and only if value = 0
                int ifeqCount = branchCount.get(functionGlobel).get(traceIfNum);
                if (tf.trans.get(0).tmToVal.get(traceIfNum).get(ifeqCount) == 0) {
                    if (insn instanceof JumpInsnNode) {
                        LabelNode jump = ((JumpInsnNode) insn).label;
                        blockid = cfg.getBasicBlockByLabel(jump.getLabel()).blockID;
                    }
                } else {
                    blockid = t.blockID + 1;
                }
                return blockid;
            case Opcodes.IFGE: // succeeds if and only if value ≥ 0
                int ifgeCount = branchCount.get(functionGlobel).get(traceIfNum);
                if (tf.trans.get(0).tmToVal.get(traceIfNum).get(ifgeCount) >= 0) {
                    if (insn instanceof JumpInsnNode) {
                        LabelNode jump = ((JumpInsnNode) insn).label;
                        blockid = cfg.getBasicBlockByLabel(jump.getLabel()).blockID;
                    }
                } else {
                    blockid = t.blockID + 1;
                }
                return blockid;
            case Opcodes.IFLE: // succeeds if and only if value ≤ 0
                int ifleCount = branchCount.get(functionGlobel).get(traceIfNum);
                if (tf.trans.get(0).tmToVal.get(traceIfNum).get(ifleCount) <= 0) {
                    if (insn instanceof JumpInsnNode) {
                        LabelNode jump = ((JumpInsnNode) insn).label;
                        blockid = cfg.getBasicBlockByLabel(jump.getLabel()).blockID;
                    }
                } else {
                    blockid = t.blockID + 1;
                }
                return blockid;
            case Opcodes.IFLT: // succeeds if and only if value < 0
                int ifltCount = branchCount.get(functionGlobel).get(traceIfNum);
                if (tf.trans.get(0).tmToVal.get(traceIfNum).get(ifltCount) < 0) {
                    if (insn instanceof JumpInsnNode) {
                        LabelNode jump = ((JumpInsnNode) insn).label;
                        blockid = cfg.getBasicBlockByLabel(jump.getLabel()).blockID;
                    }
                } else {
                    blockid = t.blockID + 1;
                }
                return blockid;
            case Opcodes.IFGT: // succeeds if and only if value > 0
                int ifgtCount = branchCount.get(functionGlobel).get(traceIfNum);
                if (tf.trans.get(0).tmToVal.get(traceIfNum).get(ifgtCount) < 0) {
                    if (insn instanceof JumpInsnNode) {
                        LabelNode jump = ((JumpInsnNode) insn).label;
                        blockid = cfg.getBasicBlockByLabel(jump.getLabel()).blockID;
                    }
                } else {
                    blockid = t.blockID + 1;
                }
                return blockid;
            case Opcodes.IF_ACMPEQ: // succeeds if and only if value1 = value2
            case Opcodes.IF_ACMPNE: // succeeds if and only if value1 ≠ value2
            case Opcodes.IF_ICMPEQ: // succeeds if and only if value1 = value2
            case Opcodes.IF_ICMPGE: // succeeds if and only if value1 ≠ value2
            case Opcodes.IF_ICMPGT: // succeeds if and only if value1 > value2
            case Opcodes.IF_ICMPLE: // succeeds if and only if value1 ≤ value2
            case Opcodes.IF_ICMPLT: // succeeds if and only if value1 < value2
            case Opcodes.IFNONNULL:
            case Opcodes.TABLESWITCH:
            case Opcodes.LOOKUPSWITCH:
            default:
                break;
            case Opcodes.GOTO:
                if (insn instanceof JumpInsnNode) {
                    LabelNode jump = ((JumpInsnNode) insn).label;
                    blockid = cfg.getBasicBlockByLabel(jump.getLabel()).blockID;
                }
        }
        return blockid;
    }
}
